# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under both the MIT license found in the
# LICENSE-MIT file in the root directory of this source tree and the Apache
# License, Version 2.0 found in the LICENSE-APACHE file in the root directory
# of this source tree.

import os
import subprocess
import sys

import dep_file_utils

# Marker found on each included-header line of the compiler output handled in
# this file; lines containing it are treated as dependency entries, all other
# lines are forwarded to stderr.
DEP_PREFIX = "Note: including file:"
# output_path -> path to write the dep file to
# cmd_args -> command to be run to get dependencies from compiler
# input_file -> Path to the file we're generating the dep file for. We need this since
# when generating dependencies for a file using show includes, the output does not include
# the file itself, so we need the path to add it manually
def process_show_includes_dep_file(output_path, cmd_args, input_file):
    """
    Run the compiler command and turn its stdout into a dep file.

    output_path: where the dep file is written when the command succeeds.
    cmd_args: command passed to subprocess.run; only its stdout is captured
        (decoded as UTF-8).
    input_file: path of the file being compiled; forwarded to
        rewrite_dep_file_for_msvc so it can be listed in the dep file.

    This function never returns: it always ends the process via sys.exit with
    the command's return code.
    """
    completed = subprocess.run(cmd_args, stdout=subprocess.PIPE, encoding="utf-8")
    exit_code = completed.returncode
    captured = completed.stdout

    if exit_code != 0:
        # Failed run: no dep file is written, only the non-dependency lines
        # of the captured stdout are surfaced.
        parse_stdout_error_output(captured)
    else:
        rewrite_dep_file_for_msvc(captured, output_path, input_file)

    sys.exit(exit_code)


def rewrite_dep_file_for_msvc(output, dst_path, input_file):
    """
    Convert stdout generated by MSVC to dep file. This will be a mix of output like:

    file.cpp
    Note: including file: path/to/dep1.h
    Note: including file:  path/to/dep2.h
    Note: including file:   path/to/dep3.h
    error: this is an error!

    and we want to get:

    path/to/dep1.h
    path/to/dep2.h
    path/to/dep3.h

    followed by input_file itself, since the compiler output does not list it
    as a dependency. The collected paths are passed through
    dep_file_utils.normalize_deps (together with the current working directory
    plus a trailing separator) before being written, one per line.

    output: captured compiler stdout as a single string.
    dst_path: path of the dep file to write (UTF-8 encoded).
    input_file: path of the file the dependencies were generated for.

    Every line after the first that does not contain DEP_PREFIX (such as the
    error line above) is echoed to stderr instead of being written.
    """
    here = os.getcwd() + os.sep
    deps = []
    # First line is the name of the file we're generating deps for.
    # We manually include it later so let's ignore it.
    for line in output.splitlines()[1:]:
        if DEP_PREFIX in line:
            # Drop the marker, then strip the variable-width indentation that
            # follows it (see the example in the docstring).
            deps.append(line.replace(DEP_PREFIX, "").strip())
        else:
            print(line, file=sys.stderr)
    deps.append(input_file)
    normalized_deps = dep_file_utils.normalize_deps(deps, here)

    # process_show_includes_dep_file decodes the compiler output as UTF-8, so
    # write with the same encoding explicitly instead of the platform default;
    # otherwise non-ASCII paths may be mis-encoded or fail to encode at all.
    with open(dst_path, "w", encoding="utf-8") as f:
        for dep in normalized_deps:
            f.write(dep)
            f.write("\n")


def parse_stdout_error_output(output):
    """
    Echo the non-dependency lines of captured compiler stdout to stderr.

    The first line of output is skipped, and so is every line containing
    DEP_PREFIX; whatever remains is printed to stderr in its original order.
    """
    remaining = output.splitlines()[1:]
    diagnostics = [text for text in remaining if DEP_PREFIX not in text]
    for text in diagnostics:
        print(text, file=sys.stderr)
